[WIP][DSV3] Avoid keeping a copy of GroupedExperts weights, free memory in StateDictAdapter #1585
base: main
Conversation
Had some comments. Need @fegin 's input as well.
@@ -61,6 +61,7 @@ python scripts/checkpoint_conversion/convert_from_hf.py <hf_checkpoints_dir> <dc
Some limitations:
1. It can't be used to convert HF checkpoints on the fly using GPU DTensor, because sharding and quantized blocks may not be aligned well, causing silent numerical incorrectness.
2. It can't be used for weight sync to generate a bf16 state dict, because fake quantization to fp8 is applied.
3. When converting GroupedExperts weights from HF's separate per-expert weights on the fly, `torch.split()` causes huge GPU memory usage. This is because torchtitan's GroupedExperts weight has shape `(num_experts, dim1, dim2)` and FSDP shards it on dim 0 by default. Calling `torch.split()` on dim 0 in the `to_hf()` function therefore incurs an all-gather and produces replicated expert memory.
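For a sense of scale, here is a back-of-envelope sketch of the memory that all-gather implies. The expert shapes, dtype, and FSDP degree below are illustrative assumptions, not values taken from this PR:

```python
# Illustrative memory arithmetic for the all-gather described above.
num_experts, dim1, dim2 = 256, 2048, 7168   # assumed GroupedExperts weight shape
bytes_per_elem = 2                          # bf16
fsdp_degree = 8                             # weight sharded on dim 0 across 8 ranks

full_bytes = num_experts * dim1 * dim2 * bytes_per_elem
print(f"per-rank shard:              {full_bytes / fsdp_degree / 2**30:.2f} GiB")
print(f"replicated after all-gather: {full_bytes / 2**30:.2f} GiB")
```

Under these assumptions, each rank goes from holding under 1 GiB of the grouped weight to holding the full ~7 GiB copy once the split forces replication.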
I thought more about this. Even if FSDP shards on dim-1, EP will shard on dim-0 anyway. So the problem still exists. Let's discuss next week.
Can we perform a `redistribute()` before `split()` to ensure the expert parameter is sharded on dim-1? This redistributed, dim-1 sharded parameter would be used exclusively by the `split()`.
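A minimal sketch of this redistribute-before-split idea, using the public DTensor API from recent PyTorch. The shapes, the `Shard(0)` starting placement, and the launch setup are illustrative assumptions, not the adapter's actual code:

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Launch with torchrun so WORLD_SIZE is set and a process group can be formed.
world_size = int(os.environ["WORLD_SIZE"])
device = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device, (world_size,))

# Stand-in for a GroupedExperts weight of shape (num_experts, dim1, dim2),
# FSDP-style sharded on dim 0. Shapes here are small and purely illustrative.
torch.manual_seed(0)
weight = distribute_tensor(torch.randn(8, 16, 32), mesh, placements=[Shard(0)])

# Reshard onto dim 1 first, so that splitting along dim 0 afterwards does not
# need to materialize the full tensor on every rank.
resharded = weight.redistribute(mesh, placements=[Shard(1)])
per_expert = torch.split(resharded, 1, dim=0)  # each slice stays sharded on dim 1
```

Whether this is actually cheaper than the all-gather in practice is exactly what the following replies debate.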
With EP it's sharded on dim-0 anyway. Performing this redistribute means at least 1 comm in to_hf and at least 1 comm in from_hf.
If both EP and FSDP dim-0 sharding is used, we'll have strided sharding whose redistribute algo today may not be efficient or even correct.
The redistribution algorithm should be correct, but whether it is going to be efficient is debatable. I think it will be more efficient than an all-gather, as less communication is incurred, even if it is not the optimal approach.
There should be no extra comm in `from_hf`, as DCP.load will handle the resharding, but that resharding can be slow for sure.
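A rough, hedged comparison of the per-rank communication the two paths would incur, modeling the `Shard(0)` → `Shard(1)` resharding as an all-to-all. Sizes are the same illustrative assumptions as above, not measurements:

```python
# Back-of-envelope per-rank communication for the two options (illustrative).
num_experts, dim1, dim2, world, bytes_per_elem = 256, 2048, 7168, 8, 2
full = num_experts * dim1 * dim2 * bytes_per_elem

# Option A: all-gather the dim-0 shards, then split locally.
allgather_recv = full * (world - 1) / world        # bytes received per rank
# Option B: reshard Shard(0) -> Shard(1) (modeled as all-to-all), then split.
alltoall_recv = full * (world - 1) / world ** 2    # bytes received per rank

print(f"all-gather: {allgather_recv / 2**30:.2f} GiB received, full tensor materialized")
print(f"all-to-all: {alltoall_recv / 2**30:.2f} GiB received, only shard-sized memory kept")
```

Under these assumptions the resharding moves roughly 1/world of the data the all-gather does and never materializes the full weight on one rank, which is consistent with the intuition above.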
@@ -158,6 +158,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = new_abstract_key.format(layer_num, expert_num)
hf_state_dict[new_key] = split_values[expert_num].squeeze()

# Remove the GroupedExperts' weight from the state_dict to free memory
del value
I think for loading a checkpoint synchronously, this sounds fine. But for saving, after calling `to_hf` we may still need the original weights for the next training steps.
I see, that's a valid concern. If a user periodically saves a checkpoint in HF format, this would be an issue. I checked `checkpoint.py`, and it only supports `last_save_in_hf` in `_save_last_step`; we are not supporting saving HF checkpoints in between.
The adapter is independent of `checkpoint.py` in torchtitan. In RL weight sync, it will be called without checkpointing.
Context
Test
FSDP=8 (FSDP shard dim-0), num_experts = 256
FSDP=8 (FSDP shard dim-1), num_experts = 256